#include "ns3/boolean.h"
#include "ns3/command-line.h"
#include "ns3/config.h"
#include "ns3/double.h"
#include "ns3/eht-phy.h"
#include "ns3/enum.h"
#include "ns3/internet-stack-helper.h"
#include "ns3/ipv4-address-helper.h"
#include "ns3/log.h"
#include "ns3/mobility-helper.h"
#include "ns3/multi-model-spectrum-channel.h"
#include "ns3/on-off-helper.h"
#include "ns3/packet-sink-helper.h"
#include "ns3/packet-sink.h"
#include "ns3/rng-seed-manager.h"
#include "ns3/spectrum-wifi-helper.h"
#include "ns3/ssid.h"
#include "ns3/string.h"
#include "ns3/udp-client-server-helper.h"
#include "ns3/uinteger.h"
#include "ns3/wifi-acknowledgment.h"
#include "ns3/yans-wifi-channel.h"
#include "ns3/yans-wifi-helper.h"

#include <array>
#include <functional>
#include <numeric>

using namespace ns3;

NS_LOG_COMPONENT_DEFINE("EhtStaApLinkLink");

/**
 * \brief Collect the number of bytes received by each server application.
 *
 * \param udp true if the applications in \p serverApp are UdpServer instances,
 *            false if they are PacketSink instances (TCP)
 * \param serverApp a container of server applications
 * \param payloadSize the size in bytes of each UDP payload; only used when \p udp
 *                    is true, because the UDP branch converts a received-packet
 *                    count into bytes
 * \return the bytes received by each server application, one entry per
 *         application, in the same order as \p serverApp
 */
std::vector<uint64_t>
StaApLinkGetRxBytes(bool udp, const ApplicationContainer& serverApp, uint32_t payloadSize)
{
    const uint32_t nApps = serverApp.GetN();
    std::vector<uint64_t> rxBytes;
    rxBytes.reserve(nApps);
    for (uint32_t i = 0; i < nApps; i++)
    {
        if (udp)
        {
            // UdpServer reports received packets; widen explicitly so the
            // packets * payloadSize product is always computed in 64 bits,
            // whatever integer width GetReceived() returns.
            const uint64_t rxPackets = DynamicCast<UdpServer>(serverApp.Get(i))->GetReceived();
            rxBytes.push_back(static_cast<uint64_t>(payloadSize) * rxPackets);
        }
        else
        {
            rxBytes.push_back(DynamicCast<PacketSink>(serverApp.Get(i))->GetTotalRx());
        }
    }
    return rxBytes;
}

int
main(int argc, char* argv[])
{
    double simulationTime = 3; // seconds
    double distance = 1.0;      // meters

    int mcs = 13;
    int gi = 3200;              // ns
    uint32_t payloadSize = 700; // bytes
    // uint64_t udpMaxPacketNumber = 4294967295U;
    // udpMaxPacketNumber = 4294967295U;

    uint8_t nLinks = 1;
    std::string channelLink0 =
        "{36, 20, BAND_5GHZ, 0}"; // {channel number, channel width (MHz), PHY band, primary20
                                  // index} from wifi-phy.cc
    // std::string channelLink1 =
    //     "{40, 20, BAND_5GHZ, 0}"; // {channel number, channel width (MHz), PHY band, primary20
    //                               // index} from wifi-phy.cc

    NodeContainer staNodes;
    staNodes.Create(1);
    NodeContainer apNode;
    apNode.Create(1);

    NetDeviceContainer staDevices;
    NetDeviceContainer apDevice;
    WifiMacHelper wifiMac;
    WifiHelper wifiHelper;

    wifiHelper.SetStandard(WIFI_STANDARD_80211be);

    for (uint8_t i = 0; i < nLinks; i++)
    {
        uint64_t nonHtRefRateMbps = EhtPhy::GetNonHtReferenceRate(mcs) / 1e6;
        std::string ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
        wifiHelper.SetRemoteStationManager(i,
                                           "ns3::ConstantRateWifiManager",
                                           "DataMode",
                                           StringValue("EhtMcs" + std::to_string(mcs)),
                                           "ControlMode",
                                           StringValue(ctrlRateStr));
    }

    Ssid ssid = Ssid("ns3-80211be");

    Ptr<MultiModelSpectrumChannel> spectrumChannel = CreateObject<MultiModelSpectrumChannel>();
    Ptr<LogDistancePropagationLossModel> lossModel =
        CreateObject<LogDistancePropagationLossModel>();
    spectrumChannel->AddPropagationLossModel(lossModel);

    SpectrumWifiPhyHelper phy(nLinks);
    phy.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);
    phy.SetChannel(spectrumChannel);

    wifiMac.SetType("ns3::StaWifiMac", "Ssid", SsidValue(ssid));
    phy.Set(0, "ChannelSettings", StringValue(channelLink0));
    // phy.Set(1, "ChannelSettings", StringValue(channelLink1));
    staDevices = wifiHelper.Install(phy, wifiMac, staNodes);

    // 考虑到仅仅只有一个STA，因此就不使用多用户ACK机制

    wifiMac.SetType("ns3::ApWifiMac",
                    "Ssid",
                    SsidValue(ssid),
                    "EnableBeaconJitter",
                    BooleanValue(false));
    apDevice = wifiHelper.Install(phy, wifiMac, apNode);

    RngSeedManager::SetSeed(1);
    RngSeedManager::SetRun(1);
    int64_t streamNumber = 100;
    streamNumber += wifiHelper.AssignStreams(apDevice, streamNumber);
    streamNumber += wifiHelper.AssignStreams(
        staDevices,
        streamNumber); // 既为不同种类的device分配了不同的随机数种子，又保证了在不同的仿真实验中，随机数种子相同，从而保证了实验的可重复性。

    // Set guard interval and MPDU buffer size
    Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/HeConfiguration/GuardInterval",
                TimeValue(NanoSeconds(gi))); // 保护间隔，即每个OFDM符号之间的间隔
    
    // mobility.
    // 设置移动模型，即设置移动节点的位置，此处仅是静态模型
    MobilityHelper mobility;
    Ptr<ListPositionAllocator> positionAlloc = CreateObject<ListPositionAllocator>();

    positionAlloc->Add(Vector(0.0, 0.0, 0.0));
    positionAlloc->Add(Vector(distance, 0.0, 0.0));
    mobility.SetPositionAllocator(positionAlloc);

    mobility.SetMobilityModel("ns3::ConstantPositionMobilityModel");

    mobility.Install(apNode);
    mobility.Install(staNodes);

    /* Internet stack*/
    InternetStackHelper stack;
    stack.Install(apNode);
    stack.Install(staNodes);

    Ipv4AddressHelper address;
    address.SetBase("192.168.1.0", "255.255.255.0");
    Ipv4InterfaceContainer staNodeInterfaces;
    Ipv4InterfaceContainer apNodeInterface;

    staNodeInterfaces = address.Assign(staDevices);
    apNodeInterface = address.Assign(apDevice);

    /* Setting applications */
    ApplicationContainer serverApp;
    auto serverNodes = staNodes;
    Ipv4InterfaceContainer serverInterfaces;
    NodeContainer clientNodes;

    serverInterfaces.Add(staNodeInterfaces.Get(0));
    clientNodes.Add(apNode.Get(0));

    for (uint16_t port = 9; port < 19; port++)
    {
        // uint16_t port = 9;
        UdpServerHelper server(port);
        serverApp = server.Install(serverNodes.Get(0));
        serverApp.Start(Seconds(0.0));
        serverApp.Stop(Seconds(simulationTime + 1));

        UdpClientHelper client(serverInterfaces.GetAddress(0), port);
        client.SetAttribute("MaxPackets", UintegerValue(4294967295U));
        client.SetAttribute("Interval", TimeValue(Time("0.00001"))); // packets/s
        client.SetAttribute("PacketSize", UintegerValue(payloadSize));
        ApplicationContainer clientApp = client.Install(clientNodes.Get(0));
        clientApp.Start(Seconds(1.0));
        clientApp.Stop(Seconds(simulationTime + 1));
    }

    // 打印udp服务器端及客户端的mac地址和ip地址
    std::cout << "Server IP address: " << serverInterfaces.GetAddress(0) << std::endl;
    std::cout << "Server MAC address: " << staDevices.Get(0)->GetAddress() << std::endl;
    std::cout << "Client IP address: " << apNodeInterface.GetAddress(0) << std::endl;
    std::cout << "Client MAC address: " << apDevice.Get(0)->GetAddress() << std::endl;

    std::cout << "Number of apDevice " << apDevice.GetN() << std::endl; 
    std::cout << "Number of staDevices " << staDevices.GetN() << std::endl;

    std::string pcapFileName = "EhtStaApLink - ap";
    phy.EnablePcap(pcapFileName, apDevice, true);
    std::string pcapFileName1 = "EhtStaApLink - sta";
    phy.EnablePcap(pcapFileName1, staDevices, true);

    std::cout << "MCS value"
              << "\t\t"
              << "Channel width"
              << "\t\t"
              << "GI"
              << "\t\t\t"
              << "Throughput" << '\n';

    Simulator::Stop(Seconds(simulationTime + 1));
    Simulator::Run();

    // cumulative number of bytes received by each server application
    std::vector<uint64_t> cumulRxBytes(1, 0);
    cumulRxBytes = StaApLinkGetRxBytes(true, serverApp, payloadSize);
    uint64_t rxBytes =
        std::accumulate(cumulRxBytes.cbegin(), cumulRxBytes.cend(), 0); // 向量中元素累加
    double throughput = (rxBytes * 8) / (simulationTime * 1000000.0);   // Mbit/s

    Simulator::Destroy();
    std::cout << mcs << "\t\t\t" << "20 MHz\t\t\t" << gi << " ns\t\t\t" << throughput
              << " Mbit/s" << std::endl;

    return 0;
}